#%%
import math

#### ResNet-18
# Shapes of the 19 convolution layers used by compute_cr, one
# [channels, channels, kernel_h, kernel_w] entry per layer, grouped below by
# channel width. Judging by the 64 -> 128 style transitions the channel pair
# looks like (in, out), and no entry has 3 input channels, so the stem conv
# appears not to be listed -- NOTE(review): confirm against the model.
# In each group that widens the channels the order here is 3x3, 3x3, 1x1.
# NOTE(review): `ranks` places its 1x1 entries at indices 4, 9 and 14, while
# the 1x1 shapes here sit at indices 6, 11 and 16, so a position-by-position
# zip pairs some layers with ranks of the wrong kernel size -- confirm which
# list has the intended order.
dims = [
    # 64 channels
    [64, 64, 3, 3],
    [64, 64, 3, 3],
    [64, 64, 3, 3],
    [64, 64, 3, 3],
    # 64 -> 128 channels
    [64, 128, 3, 3],
    [128, 128, 3, 3],
    [64, 128, 1, 1],
    [128, 128, 3, 3],
    [128, 128, 3, 3],
    # 128 -> 256 channels
    [128, 256, 3, 3],
    [256, 256, 3, 3],
    [128, 256, 1, 1],
    [256, 256, 3, 3],
    [256, 256, 3, 3],
    # 256 -> 512 channels
    [256, 512, 3, 3],
    [512, 512, 3, 3],
    [256, 512, 1, 1],
    [512, 512, 3, 3],
    [512, 512, 3, 3],
]


# Per-layer rank entries, one [r, r, r_kh, r_kw] list per conv; compute_cr
# zips them position-by-position with the module-level layer shapes `dims`.
# NOTE(review): the entries ending in 1, 1 sit at indices 4, 9 and 14 here,
# but the 1x1 shapes in `dims` sit at indices 6, 11 and 16, so e.g. the
# [30, 26, 3, 3] entry is paired with the [64, 128, 1, 1] layer -- confirm
# which list has the intended order.
ranks = [
    [19, 19, 3, 3],
    [19, 19, 3, 3],
    [19, 19, 3, 3],
    [19, 19, 3, 3],

    [17, 17, 1, 1],
    [27, 19, 3, 3],
    [30, 26, 3, 3],
    [30, 27, 3, 3],
    [29, 29, 3, 3],

    [24, 24, 1, 1],
    [32, 29, 3, 3],
    [30, 29, 3, 3],
    [30, 26, 3, 3],
    [32, 32, 3, 3],

    [31, 31, 1, 1],
    [33, 33, 3, 3],
    [31, 32, 3, 3],
    [32, 31, 3, 3],
    [35, 35, 3, 3],
]

def compute_cr(ranks_list, params_compressed=None, dims_list=None):
    """Print parameter counts of the conv layers before and after factorisation.

    Each layer shape ``dim`` is paired position-by-position (via ``zip``) with a
    rank entry ``rank``. The full layer costs ``prod(dim)`` parameters; the
    factorised layer costs one ``dim[i] x rank[i]`` factor per mode plus a core
    of ``prod(rank)`` parameters (the Tucker-decomposition parameter count).

    Args:
        ranks_list: sequence of rank entries, one per layer.
        params_compressed: counts are computed and printed only when this is
            None.
        dims_list: sequence of layer shapes; defaults to the module-level
            ``dims`` so existing calls behave as before.

    Prints whether both lists have the same length, a warning for every layer
    whose rank entry exceeds its shape in some mode (usually a sign that the
    two lists are misaligned), and the totals together with
    ``1 - compressed / full``, i.e. the fraction of parameters removed
    (``nan`` when there are no layers).
    """
    if dims_list is None:
        dims_list = dims
    if params_compressed is None:
        total_params_full = 0.
        total_params_compressed = 0.
        for idx, (dim, rank) in enumerate(zip(dims_list, ranks_list)):
            if any(r > d for d, r in zip(dim, rank)):
                # A rank larger than its dimension means this pairing is not a
                # compression -- most likely the two lists are out of order.
                print(f'warning: layer {idx} rank {rank} exceeds dims {dim}')
            factor_params = sum(d * r for d, r in zip(dim, rank))
            total_params_compressed += factor_params + math.prod(rank)
            total_params_full += math.prod(dim)
        # zip() silently stops at the shorter list, so report the length check
        # on the lists that were actually used.
        print(f'check {len(ranks_list)==len(dims_list)}')
        cr = (1. - total_params_compressed/total_params_full) if total_params_full else float('nan')
        print(f'total params model {total_params_full}, compressed {total_params_compressed}, cr test {cr}')